4c309f
@@ -14,10 +14,9 @@
import javax.servlet.http.HttpServletRequestWrapper;
 import org.springframework.util.AntPathMatcher;
 
 /**
- * Validation filter for OAuth 2.0 endpoints. Ensures that clients get a 40* response for an invalid request, not a
- * regular 302 from Spring Security authentication filters. The filter also ensures that the endpoints request mapping
- * matches the incoming request, if it matches the provided {@link #setAuthorizationEndpointUrl(String)
- * authorizationEndpointUrl} or {@link #setTokenEndpointUrl(String) tokenEndpointUrl}.
+ * Ensures that the endpoint request mapping matches the incoming request: a request that matches the configured
+ * {@link #setAuthorizationEndpointUrl(String) authorizationEndpointUrl} or {@link #setTokenEndpointUrl(String)
+ * tokenEndpointUrl} is wrapped if necessary so its servlet path and request URI refer to the default endpoint URL.
  * 
  * @author Dave Syer
  */
@@ -30,7 +29,7 @@
public class EndpointValidationFilter implements Filter {
 	private String authorizationEndpointUrl = DEFAULT_AUTHORIZATION_ENDPOINT_URL;
 
 	private String tokenEndpointUrl = DEFAULT_TOKEN_ENDPOINT_URL;
-	
+
 	private AntPathMatcher matcher = new AntPathMatcher();
 
 	public void destroy() {
@@ -41,7 +40,8 @@
public class EndpointValidationFilter implements Filter {
 		HttpServletRequest servletRequest = (HttpServletRequest) request;
 		if (matches(servletRequest, authorizationEndpointUrl)) {
 			servletRequest = wrapRequest(servletRequest, DEFAULT_AUTHORIZATION_ENDPOINT_URL);
-		} else if (matches(servletRequest, tokenEndpointUrl)) {
+		}
+		else if (matches(servletRequest, tokenEndpointUrl)) {
 			servletRequest = wrapRequest(servletRequest, DEFAULT_TOKEN_ENDPOINT_URL);
 		}
 		chain.doFilter(servletRequest, response);
@@ -52,14 +52,27 @@
public class EndpointValidationFilter implements Filter {
 			return request;
 		}
 		return new HttpServletRequestWrapper(request) {
+			private String requestUri = prependContextPath(request, urlToMatch);
+			private String originalRequestUri = request.getRequestURI();
+			private String originalServletPath = request.getServletPath();
 			@Override
 			public String getRequestURI() {
-				return prependContextPath(request, urlToMatch);
+				String standard = super.getRequestURI();
+				if (standard.equals(originalRequestUri)) {
+					// Unchanged since wrapping, so report the endpoint URI (a forward dispatch would have changed it)
+					return requestUri;
+				}
+				return standard;
 			}
 
 			@Override
 			public String getServletPath() {
-				return urlToMatch;
+				String standard = super.getServletPath();
+				if (standard.equals(originalServletPath)) {
+					// Unchanged since wrapping, so report the endpoint path (a forward dispatch would have changed it)
+					return urlToMatch;
+				}
+				return super.getServletPath();
 			}
 		};
 	}
